# -*- coding: utf-8 -*-
'''
Created on 2017年5月8日

@author: ZhuJiahui506
'''

import os
import time
import numpy as np
import tensorflow as tf
from tensorflow.contrib import rnn
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.learn.python.learn.ops.seq2seq_ops import rnn_decoder

# Let GPU memory grow on demand instead of reserving it all up front.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True

def test1():
    """Classify MNIST digits with a stacked 2-layer LSTM (TF 1.x graph API).

    Each 28x28 image is treated as a 28-step sequence of 28-pixel rows;
    the last hidden state feeds a softmax classification layer.
    """
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

    lr = 1e-3
    # batch_size is a placeholder so training and testing can use
    # different batch sizes against the same graph.
    batch_size = tf.placeholder(tf.int32)  # must be tf.int32

    input_size = 28      # pixels per row = features per timestep
    timestep_size = 28   # rows per image = sequence length
    hidden_size = 256    # units per LSTM layer
    layer_num = 2        # number of stacked LSTM layers
    class_num = 10       # output classes (digits 0-9)

    _X = tf.placeholder(tf.float32, [None, 784])
    y = tf.placeholder(tf.float32, [None, class_num])
    keep_prob = tf.placeholder(tf.float32)

    ####################################################################
    # Step 1: RNN input shape = (batch_size, timestep_size, input_size)
    X = tf.reshape(_X, [-1, timestep_size, input_size])

    # Steps 2-3: one LSTM cell wrapped with output dropout.
    # BUG FIX: the original built the stack as [lstm_cell] * layer_num,
    # which makes every layer share the SAME cell object (and weights);
    # MultiRNNCell requires a distinct cell instance per layer.
    def make_cell():
        cell = rnn.BasicLSTMCell(num_units=hidden_size, forget_bias=1.0, state_is_tuple=True)
        return rnn.DropoutWrapper(cell=cell, input_keep_prob=1.0, output_keep_prob=keep_prob)

    # Step 4: stack layer_num independent cells into a multi-layer LSTM.
    mlstm_cell = rnn.MultiRNNCell([make_cell() for _ in range(layer_num)], state_is_tuple=True)

    # Step 5: all-zero initial state.
    init_state = mlstm_cell.zero_state(batch_size, dtype=tf.float32)

    # Step 6: with time_major=False, outputs.shape = [batch, time, hidden];
    # the last timestep is the sequence representation
    # (equivalently h_state = state[-1][1]).
    outputs, state = tf.nn.dynamic_rnn(mlstm_cell, inputs=X, initial_state=init_state, time_major=False)
    h_state = outputs[:, -1, :]

    # Softmax classification layer on top of the final hidden state.
    W = tf.Variable(tf.truncated_normal([hidden_size, class_num], stddev=0.1), dtype=tf.float32)
    bias = tf.Variable(tf.constant(0.1, shape=[class_num]), dtype=tf.float32)
    logits = tf.matmul(h_state, W) + bias
    y_pre = tf.nn.softmax(logits)

    # BUG FIX: the original -reduce_mean(y * log(y_pre)) is numerically
    # unstable (log(0) -> NaN) and also averages over the class axis;
    # use the fused, stable op on the raw logits instead.
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
    optimizer = tf.train.AdamOptimizer(lr).minimize(cross_entropy)

    correct_prediction = tf.equal(tf.argmax(y_pre, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

    with tf.Session(config=config) as sess:
        tf.global_variables_initializer().run()
        for i in range(2000):
            _batch_size = 128
            batch = mnist.train.next_batch(_batch_size)
            if (i + 1) % 200 == 0:
                # Evaluate on the current batch with dropout disabled.
                train_accuracy = sess.run(accuracy, feed_dict={
                    _X: batch[0], 
                    y: batch[1], 
                    keep_prob: 1.0, 
                    batch_size: _batch_size
                })
                # mnist.train.epochs_completed = number of finished epochs
                print("Iter%d, step %d, training accuracy %g" % (mnist.train.epochs_completed, (i + 1), train_accuracy))

            # The optimizer op has no useful return value; run for side effect.
            sess.run(optimizer, feed_dict={
                _X: batch[0], 
                y: batch[1], 
                keep_prob: 0.5, 
                batch_size: _batch_size
            })

        # Accuracy on the full test set.
        test_accuracy = sess.run(accuracy, feed_dict={
            _X: mnist.test.images, 
            y: mnist.test.labels, 
            keep_prob: 1.0, 
            batch_size: mnist.test.images.shape[0]
        })
        print("test accuracy %g" % test_accuracy)


def test2():
    """Bidirectional-LSTM MNIST classifier (TF 1.x static RNN API).

    Rows of each 28x28 image are fed as a 28-step sequence; the forward
    and backward LSTM outputs are concatenated before the softmax layer.
    """
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

    # Training hyper-parameters.
    learning_rate = 0.01
    max_samples = 100000
    batch_size = 128
    display_step = 10

    # Network dimensions: 28 pixels per row, 28 rows per image.
    n_input = 28
    n_steps = 28
    n_hidden = 256
    n_classes = 10

    # Graph inputs: x carries (batch, time, features) sequences.
    x = tf.placeholder("float", [None, n_steps, n_input])
    y = tf.placeholder("float", [None, n_classes])

    # Softmax parameters; 2*n_hidden because fw/bw outputs are concatenated.
    weights = tf.Variable(tf.random_normal([2 * n_hidden, n_classes]))
    biases = tf.Variable(tf.random_normal([n_classes]))

    # static_bidirectional_rnn wants a length-n_steps list of
    # (batch, n_input) tensors: time-major transpose, flatten, then split.
    sequence = tf.split(tf.reshape(tf.transpose(x, [1, 0, 2]), [-1, n_input]), n_steps)
    print(sequence)

    # One cell per direction.
    fw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
    bw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
    outputs, _, _ = rnn.static_bidirectional_rnn(fw_cell, bw_cell, sequence,
                                                 dtype=tf.float32)
    print(outputs[-1])

    # Classify from the last timestep's concatenated fw/bw output.
    pred = tf.matmul(outputs[-1], weights) + biases
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

    # Accuracy metric.
    correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    with tf.Session(config=config) as sess:
        tf.global_variables_initializer().run()

        step = 1
        while step * batch_size < max_samples:
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            # 784-pixel vectors -> (batch, 28, 28) sequences.
            batch_x = batch_x.reshape((batch_size, n_steps, n_input))
            sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
            if step % display_step == 0:
                acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})
                loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y})
                print("Iter" + str(step * batch_size) + ", Minibatch Loss = " + \
                    "{:.6f}".format(loss) + ", Training Accuracy = " + \
                    "{:.5f}".format(acc))
            step += 1

        # Evaluate on the first 10000 test images.
        test_len = 10000
        test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
        test_label = mnist.test.labels[:test_len]
        print("Testing Accuracy:", sess.run(accuracy, feed_dict={x: test_data, y: test_label}))


def BiRNN(_X, n_hidden, n_steps, n_input, _istate_fw, _istate_bw, _weights, _biases, _batch_size, _seq_len):
    """Build a bidirectional-LSTM classifier head over a batch of sequences.

    Args:
        _X: input tensor of shape (batch_size, n_steps, n_input).
        n_hidden: LSTM units per direction.
        n_steps: sequence length (number of static unroll steps).
        n_input: features per timestep.
        _istate_fw, _istate_bw: kept for caller compatibility but IGNORED —
            the original passed initial_state_fw/bw=None anyway, so the RNN
            always creates its own zero states.
        _weights: dict with 'hidden' ((n_input, 2*n_hidden)) and 'out'
            ((2*n_hidden, n_classes)) variables.
        _biases: dict with matching 'hidden' and 'out' bias variables.
        _batch_size, _seq_len: kept for caller compatibility but unused;
            the legacy sequence_length fill the original computed here was
            dead code (never passed to the RNN) and has been removed.

    Returns:
        Logits tensor of shape (batch_size, n_classes) taken from the last
        timestep's concatenated forward/backward output.
    """
    # (batch, steps, input) -> (steps, batch, input): permute n_steps and batch_size.
    _X = tf.transpose(_X, [1, 0, 2])
    # Collapse to (n_steps*batch_size, n_input) for one shared linear layer.
    _X = tf.reshape(_X, [-1, n_input])
    # Linear activation into the 2*n_hidden space consumed by the cells.
    _X = tf.matmul(_X, _weights['hidden']) + _biases['hidden']

    # One LSTM cell per direction.
    lstm_fw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
    lstm_bw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)

    # The static RNN needs a list of n_steps tensors of shape (batch, 2*n_hidden).
    _X = tf.split(_X, n_steps)

    # dtype=tf.float32 lets the RNN build zero initial states internally.
    outputs, _, _ = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, _X,
                                                 dtype=tf.float32)

    # Linear activation on the inner loop's last output.
    return tf.matmul(outputs[-1], _weights['out']) + _biases['out']


def test3():
    """Bidirectional-LSTM MNIST demo driving BiRNN().

    Adapted from https://my.oschina.net/yilian/blog/667904
    """
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

    # Training parameters.
    learning_rate = 0.001
    training_iters = 100000
    batch_size = 128
    display_step = 10

    # Network parameters: 28x28 images as 28-step sequences of 28 pixels.
    n_input = 28
    n_steps = 28
    n_hidden = 128
    n_classes = 10

    # Graph inputs; the istate placeholders use 2x n_hidden (state & cell).
    x = tf.placeholder(tf.float32, [None, n_steps, n_input])
    istate_fw = tf.placeholder(tf.float32, [None, 2 * n_hidden])
    istate_bw = tf.placeholder(tf.float32, [None, 2 * n_hidden])
    y = tf.placeholder(tf.float32, [None, n_classes])

    # Weights: 2*n_hidden because of forward + backward cells.
    weights = {
        'hidden': tf.Variable(tf.random_normal([n_input, 2 * n_hidden])),
        'out': tf.Variable(tf.random_normal([2 * n_hidden, n_classes]))
    }
    biases = {
        'hidden': tf.Variable(tf.random_normal([2 * n_hidden])),
        'out': tf.Variable(tf.random_normal([n_classes]))
    }

    pred = BiRNN(x, n_hidden, n_steps, n_input, istate_fw, istate_bw, weights, biases, batch_size, n_steps)

    # Softmax loss + Adam optimizer.
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

    # Accuracy metric.
    correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    def zero_states(size):
        # Zero initial-state feeds for both directions.
        zeros = np.zeros((size, 2 * n_hidden))
        return {istate_fw: zeros, istate_bw: zeros}

    # Launch the graph.
    with tf.Session() as sess:
        tf.global_variables_initializer().run()

        step = 1
        # Keep training until max iterations are reached.
        while step * batch_size < training_iters:
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Reshape to 28 sequences of 28 elements each.
            batch_xs = batch_xs.reshape((batch_size, n_steps, n_input))
            feed = {x: batch_xs, y: batch_ys}
            feed.update(zero_states(batch_size))
            # Fit training using the batch data.
            sess.run(optimizer, feed_dict=feed)
            if step % display_step == 0:
                # Batch accuracy and loss for progress reporting.
                acc = sess.run(accuracy, feed_dict=feed)
                loss = sess.run(cost, feed_dict=feed)
                print("Iter " + str(step*batch_size) + ", Minibatch Loss= " + "{:.6f}".format(loss) + \
                      ", Training Accuracy= " + "{:.5f}".format(acc))
            step += 1
        print("Optimization Finished!")

        # Accuracy on the first 128 mnist test images.
        test_len = 128
        test_feed = {
            x: mnist.test.images[:test_len].reshape((-1, n_steps, n_input)),
            y: mnist.test.labels[:test_len]
        }
        test_feed.update(zero_states(test_len))
        print("Testing Accuracy:", sess.run(accuracy, feed_dict=test_feed))
    


if __name__ == '__main__':
    # Only the bidirectional-LSTM demo runs by default; uncomment to try the others.
    # test1()
    test2()
    # test3()
    
    